datachain 0.16.1__py3-none-any.whl → 0.16.3__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. See the registry's notice for this release for more details.

datachain/func/array.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from collections.abc import Sequence
2
- from typing import Any, Union
2
+ from typing import Any, Optional, Union
3
3
 
4
4
  from datachain.sql.functions import array
5
5
 
@@ -178,6 +178,61 @@ def contains(arr: Union[str, Sequence, Func], elem: Any) -> Func:
178
178
  return Func("contains", inner=inner, cols=cols, args=args, result_type=int)
179
179
 
180
180
 
181
+ def get_element(arg: Union[str, Sequence, Func], index: int) -> Func:
182
+ """
183
+ Returns the element at the given index from the array.
184
+ If the index is out of bounds, it returns None or columns default value.
185
+
186
+ Args:
187
+ arg (str | Sequence | Func): Array to get the element from.
188
+ If a string is provided, it is assumed to be the name of the array column.
189
+ If a sequence is provided, it is assumed to be an array of values.
190
+ If a Func is provided, it is assumed to be a function returning an array.
191
+ index (int): Index of the element to get from the array.
192
+
193
+ Returns:
194
+ Func: A Func object that represents the array get_element function.
195
+
196
+ Example:
197
+ ```py
198
+ dc.mutate(
199
+ first_el=func.array.get_element("signal.values", 0),
200
+ second_el=func.array.get_element([1, 2, 3, 4, 5], 1),
201
+ )
202
+ ```
203
+
204
+ Note:
205
+ - Result column will always be the same type as the elements of the array.
206
+ """
207
+
208
+ def type_from_args(arr, _):
209
+ if isinstance(arr, list):
210
+ try:
211
+ return type(arr[0])
212
+ except IndexError:
213
+ return str # if the array is empty, return str as default type
214
+ return None
215
+
216
+ cols: Optional[Union[str, Sequence, Func]]
217
+ args: Union[str, Sequence, Func, int]
218
+
219
+ if isinstance(arg, (str, Func)):
220
+ cols = [arg]
221
+ args = [index]
222
+ else:
223
+ cols = None
224
+ args = [arg, index]
225
+
226
+ return Func(
227
+ "get_element",
228
+ inner=array.get_element,
229
+ cols=cols,
230
+ args=args,
231
+ from_array=True,
232
+ type_from_args=type_from_args,
233
+ )
234
+
235
+
181
236
  def sip_hash_64(arg: Union[str, Sequence]) -> Func:
182
237
  """
183
238
  Computes the SipHash-64 hash of the array.
datachain/func/func.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import inspect
2
2
  from collections.abc import Sequence
3
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
3
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin
4
4
 
5
5
  from sqlalchemy import BindParameter, Case, ColumnElement, Integer, cast, desc
6
6
  from sqlalchemy.sql import func as sa_func
@@ -36,7 +36,9 @@ class Func(Function):
36
36
  args: Optional[Sequence[Any]] = None,
37
37
  kwargs: Optional[dict[str, Any]] = None,
38
38
  result_type: Optional["DataType"] = None,
39
+ type_from_args: Optional[Callable[..., "DataType"]] = None,
39
40
  is_array: bool = False,
41
+ from_array: bool = False,
40
42
  is_window: bool = False,
41
43
  window: Optional["Window"] = None,
42
44
  label: Optional[str] = None,
@@ -47,7 +49,9 @@ class Func(Function):
47
49
  self.args = args or []
48
50
  self.kwargs = kwargs or {}
49
51
  self.result_type = result_type
52
+ self.type_from_args = type_from_args
50
53
  self.is_array = is_array
54
+ self.from_array = from_array
51
55
  self.is_window = is_window
52
56
  self.window = window
53
57
  self.col_label = label
@@ -66,7 +70,9 @@ class Func(Function):
66
70
  self.args,
67
71
  self.kwargs,
68
72
  self.result_type,
73
+ self.type_from_args,
69
74
  self.is_array,
75
+ self.from_array,
70
76
  self.is_window,
71
77
  window,
72
78
  self.col_label,
@@ -101,6 +107,20 @@ class Func(Function):
101
107
  "Columns must have the same type to infer result type",
102
108
  )
103
109
 
110
+ if self.from_array:
111
+ if get_origin(col_type) is list:
112
+ col_args = get_args(col_type)
113
+ if len(col_args) != 1:
114
+ raise DataChainColumnError(
115
+ str(self),
116
+ "Array column must have a single type argument",
117
+ )
118
+ return col_args[0]
119
+ raise DataChainColumnError(
120
+ str(self),
121
+ "Array column must be of type list",
122
+ )
123
+
104
124
  return list[col_type] if self.is_array else col_type # type: ignore[valid-type]
105
125
 
106
126
  def __add__(self, other: Union[ColT, float]) -> "Func":
@@ -339,7 +359,9 @@ class Func(Function):
339
359
  self.args,
340
360
  self.kwargs,
341
361
  self.result_type,
362
+ self.type_from_args,
342
363
  self.is_array,
364
+ self.from_array,
343
365
  self.is_window,
344
366
  self.window,
345
367
  label,
@@ -368,6 +390,15 @@ class Func(Function):
368
390
  if signals_schema and (col_type := self._db_col_type(signals_schema)):
369
391
  return col_type
370
392
 
393
+ if (
394
+ self.type_from_args
395
+ and (self.cols is None or self.cols == [])
396
+ and self.args is not None
397
+ and len(self.args) > 0
398
+ and (result_type := self.type_from_args(*self.args)) is not None
399
+ ):
400
+ return result_type
401
+
371
402
  raise DataChainColumnError(
372
403
  str(self),
373
404
  "Column name is required to infer result type",
@@ -127,9 +127,11 @@ def read_database(
127
127
  ```
128
128
 
129
129
  Notes:
130
- This function works with a variety of databases — including, but not limited to,
131
- SQLite, DuckDB, PostgreSQL, and Snowflake, provided the appropriate driver is
132
- installed.
130
+ - This function works with a variety of databases — including,
131
+ but not limited to, SQLite, DuckDB, PostgreSQL, and Snowflake,
132
+ provided the appropriate driver is installed.
133
+ - This call is blocking, and will execute the query and return once the
134
+ results are saved.
133
135
  """
134
136
  from datachain.lib.dc.records import read_records
135
137
 
@@ -37,9 +37,13 @@ def read_records(
37
37
  import datachain as dc
38
38
  single_record = dc.read_records(dc.DEFAULT_FILE_RECORD)
39
39
  ```
40
+
41
+ Notes:
42
+ This call blocks until all records are inserted.
40
43
  """
41
- from datachain.query.dataset import adjust_outputs, get_col_types
44
+ from datachain.query.dataset import INSERT_BATCH_SIZE, adjust_outputs, get_col_types
42
45
  from datachain.sql.types import SQLType
46
+ from datachain.utils import batched
43
47
 
44
48
  from .datasets import read_dataset
45
49
 
@@ -89,6 +93,7 @@ def read_records(
89
93
  {c.name: c.type for c in columns if isinstance(c.type, SQLType)},
90
94
  )
91
95
  records = (adjust_outputs(warehouse, record, col_types) for record in to_insert)
92
- warehouse.insert_rows(table, records)
96
+ for chunk in batched(records, INSERT_BATCH_SIZE):
97
+ warehouse.insert_rows(table, chunk)
93
98
  warehouse.insert_rows_done(table)
94
99
  return read_dataset(name=dsr.name, session=session, settings=settings)
@@ -48,6 +48,16 @@ class contains(GenericFunction): # noqa: N801
48
48
  inherit_cache = True
49
49
 
50
50
 
51
+ class get_element(GenericFunction): # noqa: N801
52
+ """
53
+ Returns the element at the given index in the array.
54
+ """
55
+
56
+ package = "array"
57
+ name = "get_element"
58
+ inherit_cache = True
59
+
60
+
51
61
  class sip_hash_64(GenericFunction): # noqa: N801
52
62
  """
53
63
  Computes the SipHash-64 hash of the array.
@@ -63,4 +73,5 @@ compiler_not_implemented(cosine_distance)
63
73
  compiler_not_implemented(euclidean_distance)
64
74
  compiler_not_implemented(length)
65
75
  compiler_not_implemented(contains)
76
+ compiler_not_implemented(get_element)
66
77
  compiler_not_implemented(sip_hash_64)
@@ -88,6 +88,7 @@ def setup():
88
88
  compiles(sql_path.file_ext, "sqlite")(compile_path_file_ext)
89
89
  compiles(array.length, "sqlite")(compile_array_length)
90
90
  compiles(array.contains, "sqlite")(compile_array_contains)
91
+ compiles(array.get_element, "sqlite")(compile_array_get_element)
91
92
  compiles(string.length, "sqlite")(compile_string_length)
92
93
  compiles(string.split, "sqlite")(compile_string_split)
93
94
  compiles(string.regexp_replace, "sqlite")(compile_string_regexp_replace)
@@ -270,6 +271,13 @@ def register_user_defined_sql_functions() -> None:
270
271
 
271
272
  _registered_function_creators["string_functions"] = create_string_functions
272
273
 
274
+ def create_array_functions(conn):
275
+ conn.create_function(
276
+ "json_array_get_element", 2, py_json_array_get_element, deterministic=True
277
+ )
278
+
279
+ _registered_function_creators["array_functions"] = create_array_functions
280
+
273
281
  has_json_extension = functions_exist(["json_array_length", "json_array_contains"])
274
282
  if not has_json_extension:
275
283
 
@@ -438,6 +446,20 @@ def py_json_array_contains(arr, value, is_json):
438
446
  return value in orjson.loads(arr)
439
447
 
440
448
 
449
+ def py_json_array_get_element(val, idx):
450
+ arr = orjson.loads(val)
451
+ try:
452
+ return arr[idx]
453
+ except IndexError:
454
+ return None
455
+
456
+
457
+ def compile_array_get_element(element, compiler, **kwargs):
458
+ return compiler.process(
459
+ func.json_array_get_element(*element.clauses.clauses), **kwargs
460
+ )
461
+
462
+
441
463
  def compile_array_length(element, compiler, **kwargs):
442
464
  return compiler.process(func.json_array_length(*element.clauses.clauses), **kwargs)
443
465
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datachain
3
- Version: 0.16.1
3
+ Version: 0.16.3
4
4
  Summary: Wrangle unstructured AI data at scale
5
5
  Author-email: Dmitry Petrov <support@dvc.org>
6
6
  License-Expression: Apache-2.0
@@ -23,7 +23,7 @@ Requires-Dist: tqdm
23
23
  Requires-Dist: numpy<3,>=1
24
24
  Requires-Dist: pandas>=2.0.0
25
25
  Requires-Dist: packaging
26
- Requires-Dist: pyarrow
26
+ Requires-Dist: pyarrow<20
27
27
  Requires-Dist: typing-extensions
28
28
  Requires-Dist: python-dateutil>=2
29
29
  Requires-Dist: attrs>=21.3.0
@@ -56,10 +56,10 @@ datachain/fs/reference.py,sha256=A8McpXF0CqbXPqanXuvpKu50YLB3a2ZXA3YAPxtBXSM,914
56
56
  datachain/fs/utils.py,sha256=s-FkTOCGBk-b6TT3toQH51s9608pofoFjUSTc1yy7oE,825
57
57
  datachain/func/__init__.py,sha256=CjNLHfJkepdXdRZ6HjJBjNSIjOeFMuMkwPDaPUrM75g,1270
58
58
  datachain/func/aggregate.py,sha256=UfxENlw56Qv3UEkj2sZ-JZHmr9q8Rnic9io9_63gF-E,10942
59
- datachain/func/array.py,sha256=O784_uwmaP5CjZX4VSF4RmS8cmpaForQc8zASxHJB6A,6717
59
+ datachain/func/array.py,sha256=OmmjdK5AQyBXa_7NXFUPY_m3lFlRK4Um4J9NCtYwvak,8394
60
60
  datachain/func/base.py,sha256=wA0sBQAVyN9LPxoo7Ox83peS0zUVnyuKxukwAcjGLfY,534
61
61
  datachain/func/conditional.py,sha256=HkNamQr9dLyIMDEbIeO6CZR0emQoDqeaWrZ1fECod4M,8062
62
- datachain/func/func.py,sha256=k8z5tIiabEOPymYWGfz4O7z1qS6zBZnVYRPp_58iU7c,16192
62
+ datachain/func/func.py,sha256=jzEvnc2iN0BAly-uzxhaoMntL_xF4j94DrFuRi2ADSw,17321
63
63
  datachain/func/numeric.py,sha256=gMe1Ks0dqQKHkjcpvj7I5S-neECzQ_gltPQLNoaWOyo,5632
64
64
  datachain/func/path.py,sha256=mqN_mfkwv44z2II7DMTp_fGGw95hmTCNls_TOFNpr4k,3155
65
65
  datachain/func/random.py,sha256=pENOLj9rSmWfGCnOsUIaCsVC5486zQb66qfQvXaz9Z4,452
@@ -96,7 +96,7 @@ datachain/lib/convert/unflatten.py,sha256=ysMkstwJzPMWUlnxn-Z-tXJR3wmhjHeSN_P-sD
96
96
  datachain/lib/convert/values_to_tuples.py,sha256=j5yZMrVUH6W7b-7yUvdCTGI7JCUAYUOzHUGPoyZXAB0,4360
97
97
  datachain/lib/dc/__init__.py,sha256=HD0NYrdy44u6kkpvgGjJcvGz-UGTHui2azghcT8ZUg0,838
98
98
  datachain/lib/dc/csv.py,sha256=asWPAxhMgIoLAdD2dObDlnGL8CTSD3TAuFuM4ci89bQ,4374
99
- datachain/lib/dc/database.py,sha256=gYKh1iO5hOWMPFTU1vZC5kOXkJzVse14TYTWE4_1iEA,5940
99
+ datachain/lib/dc/database.py,sha256=g5M6NjYR1T0vKte-abV-3Ejnm-HqxTIMir5cRi_SziE,6051
100
100
  datachain/lib/dc/datachain.py,sha256=36J8QIB04hKKumQgLvHNTC94Pd7G2yE4slZ9RfwI9zw,76980
101
101
  datachain/lib/dc/datasets.py,sha256=u6hlz0Eodh_s39TOW6kz0VIL3nGfadqu8FLoWqDxSJs,6890
102
102
  datachain/lib/dc/hf.py,sha256=PJl2wiLjdRsMz0SYbLT-6H8b-D5i2WjeH7li8HHOk_0,2145
@@ -104,7 +104,7 @@ datachain/lib/dc/json.py,sha256=ZUThPDAaP2gBFIL5vsQTwKBcuN_dhvC_O44wdDv0jEc,2683
104
104
  datachain/lib/dc/listings.py,sha256=2na9v63xO1vPUNaoBSzA-TSN49V7zQAb-4iS1wOPLFE,1029
105
105
  datachain/lib/dc/pandas.py,sha256=ObueUXDUFKJGu380GmazdG02ARpKAHPhSaymfmOH13E,1489
106
106
  datachain/lib/dc/parquet.py,sha256=zYcSgrWwyEDW9UxGUSVdIVsCu15IGEf0xL8KfWQqK94,1782
107
- datachain/lib/dc/records.py,sha256=Z6EWy6c6hf87cWiDlQduvrDgOHMLwqF22g-XksOnXsU,2884
107
+ datachain/lib/dc/records.py,sha256=J1I69J2gFIBjRTGr2LG-5qn_rTVzRLcr2y3tVDrmHdg,3068
108
108
  datachain/lib/dc/storage.py,sha256=QLf3-xMV2Gmy3AA8qF9WqAsb7R8Rk87l4s5hBoiCH98,5285
109
109
  datachain/lib/dc/utils.py,sha256=VawOAlJSvAtZbsMg33s5tJe21TRx1Km3QggI1nN6tnw,3984
110
110
  datachain/lib/dc/values.py,sha256=cBQubhmPNEDMJldUXzGh-UKbdim4P6O2B91Gp39roKw,1389
@@ -138,22 +138,22 @@ datachain/sql/default/__init__.py,sha256=XQ2cEZpzWiABqjV-6yYHUBGI9vN_UHxbxZENESm
138
138
  datachain/sql/default/base.py,sha256=QD-31C6JnyOXzogyDx90sUhm7QvgXIYpeHEASH84igU,628
139
139
  datachain/sql/functions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
140
140
  datachain/sql/functions/aggregate.py,sha256=3AQdA8YHPFdtCEfwZKQXTT8SlQWdG9gD5PBtGN3Odqs,944
141
- datachain/sql/functions/array.py,sha256=ZXPHNQxG-t5tl7Gr0nubu4oc-9fBZmX5q03z9ljWg-I,1443
141
+ datachain/sql/functions/array.py,sha256=LkVDq1Iu3pF2vz9opjxcS0oSvEepoLNYVXcmnOzsmY0,1679
142
142
  datachain/sql/functions/conditional.py,sha256=q7YUKfunXeEldXaxgT-p5pUTcOEVU_tcQ2BJlquTRPs,207
143
143
  datachain/sql/functions/numeric.py,sha256=BK2KCiPSgM2IveCq-9M_PG3CtPBlztaS9TTn1LGzyLs,1250
144
144
  datachain/sql/functions/path.py,sha256=zixpERotTFP6LZ7I4TiGtyRA8kXOoZmH1yzH9oRW0mg,1294
145
145
  datachain/sql/functions/random.py,sha256=vBwEEj98VH4LjWixUCygQ5Bz1mv1nohsCG0-ZTELlVg,271
146
146
  datachain/sql/functions/string.py,sha256=E-T9OIzUR-GKaLgjZsEtg5CJrY_sLf1lt1awTvY7w2w,1426
147
147
  datachain/sql/sqlite/__init__.py,sha256=TAdJX0Bg28XdqPO-QwUVKy8rg78cgMileHvMNot7d04,166
148
- datachain/sql/sqlite/base.py,sha256=N-cQT0Hpu9ROWe4OiKlkkn_YP1NKCRZZ3xSfTzpyaDA,19651
148
+ datachain/sql/sqlite/base.py,sha256=SktpdtyZmxG9ip_UX_0WL3YxP9o66CYTeMfriRrZzaE,20281
149
149
  datachain/sql/sqlite/types.py,sha256=cH6oge2E_YWFy22wY-txPJH8gxoQFSpCthtZR8PZjpo,1849
150
150
  datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
151
151
  datachain/toolkit/__init__.py,sha256=eQ58Q5Yf_Fgv1ZG0IO5dpB4jmP90rk8YxUWmPc1M2Bo,68
152
152
  datachain/toolkit/split.py,sha256=ktGWzY4kyzjWyR86dhvzw-Zhl0lVk_LOX3NciTac6qo,2914
153
153
  datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
154
- datachain-0.16.1.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
155
- datachain-0.16.1.dist-info/METADATA,sha256=9YPqP6Sthuf_fuxFX3miQyp9MEjRq8j2DqubLXvZg0k,11328
156
- datachain-0.16.1.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
157
- datachain-0.16.1.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
158
- datachain-0.16.1.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
159
- datachain-0.16.1.dist-info/RECORD,,
154
+ datachain-0.16.3.dist-info/licenses/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
155
+ datachain-0.16.3.dist-info/METADATA,sha256=3XVhBEDIISei-EeF3LEpAimkbJ2vQ1yHAEKWZqgfVbs,11331
156
+ datachain-0.16.3.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
157
+ datachain-0.16.3.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
158
+ datachain-0.16.3.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
159
+ datachain-0.16.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (79.0.0)
2
+ Generator: setuptools (80.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5