datachain 0.2.17__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of datachain might be problematic. Click here for more details.

datachain/lib/convert/sql_to_python.py CHANGED
@@ -1,23 +1,18 @@
1
- from datetime import datetime
1
+ from decimal import Decimal
2
2
  from typing import Any
3
3
 
4
- from sqlalchemy import ARRAY, JSON, Boolean, DateTime, Float, Integer, String
4
+ from sqlalchemy import ColumnElement
5
5
 
6
- from datachain.data_storage.sqlite import Column
7
6
 
8
- SQL_TO_PYTHON = {
9
- String: str,
10
- Integer: int,
11
- Float: float,
12
- Boolean: bool,
13
- DateTime: datetime,
14
- ARRAY: list,
15
- JSON: dict,
16
- }
7
+ def sql_to_python(args_map: dict[str, ColumnElement]) -> dict[str, Any]:
8
+ res = {}
9
+ for name, sql_exp in args_map.items():
10
+ try:
11
+ type_ = sql_exp.type.python_type
12
+ if type_ == Decimal:
13
+ type_ = float
14
+ except NotImplementedError:
15
+ type_ = str
16
+ res[name] = type_
17
17
 
18
-
19
- def sql_to_python(args_map: dict[str, Column]) -> dict[str, Any]:
20
- return {
21
- k: SQL_TO_PYTHON.get(type(v.type), str) # type: ignore[union-attr]
22
- for k, v in args_map.items()
23
- }
18
+ return res
datachain/lib/dc.py CHANGED
@@ -20,8 +20,10 @@ import pandas as pd
20
20
  import sqlalchemy
21
21
  from pydantic import BaseModel, create_model
22
22
  from sqlalchemy.sql.functions import GenericFunction
23
+ from sqlalchemy.sql.sqltypes import NullType
23
24
 
24
25
  from datachain import DataModel
26
+ from datachain.lib.convert.python_to_sql import python_to_sql
25
27
  from datachain.lib.convert.values_to_tuples import values_to_tuples
26
28
  from datachain.lib.data_model import DataType
27
29
  from datachain.lib.dataset_info import DatasetInfo
@@ -110,6 +112,11 @@ class DatasetMergeError(DataChainParamsError): # noqa: D101
110
112
  super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
111
113
 
112
114
 
115
+ class DataChainColumnError(DataChainParamsError): # noqa: D101
116
+ def __init__(self, col_name, msg): # noqa: D107
117
+ super().__init__(f"Error for column {col_name}: {msg}")
118
+
119
+
113
120
  OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
114
121
 
115
122
 
@@ -225,6 +232,17 @@ class DataChain(DatasetQuery):
225
232
  """Get schema of the chain."""
226
233
  return self._effective_signals_schema.values
227
234
 
235
+ def column(self, name: str) -> Column:
236
+ """Returns Column instance with a type if name is found in current schema,
237
+ otherwise raises an exception.
238
+ """
239
+ name_path = name.split(".")
240
+ for path, type_, _, _ in self.signals_schema.get_flat_tree():
241
+ if path == name_path:
242
+ return Column(name, python_to_sql(type_))
243
+
244
+ raise ValueError(f"Column with name {name} not found in the schema")
245
+
228
246
  def print_schema(self) -> None:
229
247
  """Print schema of the chain."""
230
248
  self._effective_signals_schema.print_tree()
@@ -829,6 +847,12 @@ class DataChain(DatasetQuery):
829
847
  )
830
848
  ```
831
849
  """
850
+ for col_name, expr in kwargs.items():
851
+ if not isinstance(expr, Column) and isinstance(expr.type, NullType):
852
+ raise DataChainColumnError(
853
+ col_name, f"Cannot infer type with expression {expr}"
854
+ )
855
+
832
856
  mutated = {}
833
857
  schema = self.signals_schema
834
858
  for name, value in kwargs.items():
datachain/sql/functions/__init__.py CHANGED
@@ -1,16 +1,17 @@
1
1
  from sqlalchemy.sql.expression import func
2
2
 
3
- from . import path, string
3
+ from . import array, path, string
4
+ from .array import avg
4
5
  from .conditional import greatest, least
5
6
  from .random import rand
6
7
 
7
8
  count = func.count
8
9
  sum = func.sum
9
- avg = func.avg
10
10
  min = func.min
11
11
  max = func.max
12
12
 
13
13
  __all__ = [
14
+ "array",
14
15
  "avg",
15
16
  "count",
16
17
  "func",
datachain/sql/functions/array.py CHANGED
@@ -44,7 +44,15 @@ class sip_hash_64(GenericFunction): # noqa: N801
44
44
  inherit_cache = True
45
45
 
46
46
 
47
+ class avg(GenericFunction): # noqa: N801
48
+ type = Float()
49
+ package = "array"
50
+ name = "avg"
51
+ inherit_cache = True
52
+
53
+
47
54
  compiler_not_implemented(cosine_distance)
48
55
  compiler_not_implemented(euclidean_distance)
49
56
  compiler_not_implemented(length)
50
57
  compiler_not_implemented(sip_hash_64)
58
+ compiler_not_implemented(avg)
datachain/sql/sqlite/base.py CHANGED
@@ -78,6 +78,7 @@ def setup():
78
78
  compiles(conditional.least, "sqlite")(compile_least)
79
79
  compiles(Values, "sqlite")(compile_values)
80
80
  compiles(random.rand, "sqlite")(compile_rand)
81
+ compiles(array.avg, "sqlite")(compile_avg)
81
82
 
82
83
  if load_usearch_extension(sqlite3.connect(":memory:")):
83
84
  compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
@@ -349,6 +350,10 @@ def compile_rand(element, compiler, **kwargs):
349
350
  return compiler.process(func.random(), **kwargs)
350
351
 
351
352
 
353
+ def compile_avg(element, compiler, **kwargs):
354
+ return compiler.process(func.avg(*element.clauses.clauses), **kwargs)
355
+
356
+
352
357
  def load_usearch_extension(conn) -> bool:
353
358
  try:
354
359
  # usearch is part of the vector optional dependencies
datachain-0.2.17.dist-info/METADATA → datachain-0.2.18.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: datachain
3
- Version: 0.2.17
3
+ Version: 0.2.18
4
4
  Summary: Wrangle unstructured AI data at scale
5
5
  Author-email: Dmitry Petrov <support@dvc.org>
6
6
  License: Apache-2.0
datachain-0.2.17.dist-info/RECORD → datachain-0.2.18.dist-info/RECORD CHANGED
@@ -42,7 +42,7 @@ datachain/lib/arrow.py,sha256=R8wDUDEa-5hYjI3HW9cqvOYYJpeeah5lbhFIL3gkmcE,4915
42
42
  datachain/lib/clip.py,sha256=16u4b_y2Y15nUS2UN_8ximMo6r_-_4IQpmct2ol-e-g,5730
43
43
  datachain/lib/data_model.py,sha256=qfTtQNncS5pt9SvXdMEa5kClniaT6XBGBfO7onEz2TI,1632
44
44
  datachain/lib/dataset_info.py,sha256=lONGr71ozo1DS4CQEhnpKORaU4qFb6Ketv8Xm8CVm2U,2188
45
- datachain/lib/dc.py,sha256=bZx7VJ389SJ5gRTkckFD044LHq_hOgHqvhTD7gJoBZY,56963
45
+ datachain/lib/dc.py,sha256=F2DrvBLxsLDHY7wDVzMFj_-IRscDxb_STTRMqd0gmyw,57971
46
46
  datachain/lib/file.py,sha256=MCklths3w9SgQTR0LACnDohfGdEc3t30XD0qNq1oTlI,12000
47
47
  datachain/lib/image.py,sha256=TgYhRhzd4nkytfFMeykQkPyzqb5Le_-tU81unVMPn4Q,2328
48
48
  datachain/lib/meta_formats.py,sha256=jlSYWRUeDMjun_YCsQ2JxyaDJpEpokzHDPmKUAoCXnU,7034
@@ -60,7 +60,7 @@ datachain/lib/webdataset_laion.py,sha256=PQP6tQmUP7Xu9fPuAGK1JDBYA6T5UufYMUTGaxg
60
60
  datachain/lib/convert/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
61
61
  datachain/lib/convert/flatten.py,sha256=YMoC00BqEy3zSpvCp6Q0DfxihuPmgjUJj1g2cesWGPs,1790
62
62
  datachain/lib/convert/python_to_sql.py,sha256=4gplGlr_Kg-Z40OpJUzJiarDWj7pwbUOk-dPOYYCJ9Q,2629
63
- datachain/lib/convert/sql_to_python.py,sha256=HK414fexSQ4Ur-OY7_pKvDKEGdtos1CeeAFa4RxH4nU,532
63
+ datachain/lib/convert/sql_to_python.py,sha256=lGnKzSF_tz9Y_5SSKkrIU95QEjpcDzvOxIRkEKTQag0,443
64
64
  datachain/lib/convert/unflatten.py,sha256=Ogvh_5wg2f38_At_1lN0D_e2uZOOpYEvwvB2xdq56Tw,2012
65
65
  datachain/lib/convert/values_to_tuples.py,sha256=aVoHWMOUGLAiS6_BBwKJqVIne91VffOW6-dWyNE7oHg,3715
66
66
  datachain/query/__init__.py,sha256=tv-spkjUCYamMN9ys_90scYrZ8kJ7C7d1MTYVmxGtk4,325
@@ -81,20 +81,20 @@ datachain/sql/types.py,sha256=SShudhdIpdfTKDxWDDqOajYRkTCkIgQbilA94g4i-4E,10389
81
81
  datachain/sql/utils.py,sha256=rzlJw08etivdrcuQPqNVvVWhuVSyUPUQEEc6DOhu258,818
82
82
  datachain/sql/default/__init__.py,sha256=XQ2cEZpzWiABqjV-6yYHUBGI9vN_UHxbxZENESmVAWw,45
83
83
  datachain/sql/default/base.py,sha256=h44005q3qtMc9cjWmRufWwcBr5CfK_dnvG4IrcSQs_8,536
84
- datachain/sql/functions/__init__.py,sha256=PP8XV1CC1naIu87fiExbJRpV0Rww47EcDrDIKJb_xBQ,368
85
- datachain/sql/functions/array.py,sha256=rvH27SWN9gdh_mFnp0GIiXuCrNW6n8ZbY4I_JUS-_e0,1140
84
+ datachain/sql/functions/__init__.py,sha256=Ioyy7nSetrTLVnHGcGcmZU99HxUFcx-5PFbrh2dPNH0,396
85
+ datachain/sql/functions/array.py,sha256=EB7nJSncUc1PuxlHyzU2gVhF8DuXaxpGlxb5e8X2KFY,1297
86
86
  datachain/sql/functions/conditional.py,sha256=q7YUKfunXeEldXaxgT-p5pUTcOEVU_tcQ2BJlquTRPs,207
87
87
  datachain/sql/functions/path.py,sha256=zixpERotTFP6LZ7I4TiGtyRA8kXOoZmH1yzH9oRW0mg,1294
88
88
  datachain/sql/functions/random.py,sha256=vBwEEj98VH4LjWixUCygQ5Bz1mv1nohsCG0-ZTELlVg,271
89
89
  datachain/sql/functions/string.py,sha256=hIrF1fTvlPamDtm8UMnWDcnGfbbjCsHxZXS30U2Rzxo,651
90
90
  datachain/sql/sqlite/__init__.py,sha256=TAdJX0Bg28XdqPO-QwUVKy8rg78cgMileHvMNot7d04,166
91
- datachain/sql/sqlite/base.py,sha256=Jb1csbIARjEvwbylnvgNA7ChozSyoL3CQzOGBUf8QAw,12067
91
+ datachain/sql/sqlite/base.py,sha256=LBYmXqXsVF30fbcnR55evCZHbPDCzMdGk_ogPLps63s,12236
92
92
  datachain/sql/sqlite/types.py,sha256=yzvp0sXSEoEYXs6zaYC_2YubarQoZH-MiUNXcpuEP4s,1573
93
93
  datachain/sql/sqlite/vector.py,sha256=ncW4eu2FlJhrP_CIpsvtkUabZlQdl2D5Lgwy_cbfqR0,469
94
94
  datachain/torch/__init__.py,sha256=gIS74PoEPy4TB3X6vx9nLO0Y3sLJzsA8ckn8pRWihJM,579
95
- datachain-0.2.17.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
96
- datachain-0.2.17.dist-info/METADATA,sha256=STR0-4R9NOW55GgadrPA_-fmx5-WckcwhTmyH_OgaUs,17269
97
- datachain-0.2.17.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
98
- datachain-0.2.17.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
99
- datachain-0.2.17.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
100
- datachain-0.2.17.dist-info/RECORD,,
95
+ datachain-0.2.18.dist-info/LICENSE,sha256=8DnqK5yoPI_E50bEg_zsHKZHY2HqPy4rYN338BHQaRA,11344
96
+ datachain-0.2.18.dist-info/METADATA,sha256=_wZgyu8nS5Ut_kQcIc_n9979rQcvv8fPuSIHbyCGhX0,17269
97
+ datachain-0.2.18.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
98
+ datachain-0.2.18.dist-info/entry_points.txt,sha256=0GMJS6B_KWq0m3VT98vQI2YZodAMkn4uReZ_okga9R4,49
99
+ datachain-0.2.18.dist-info/top_level.txt,sha256=lZPpdU_2jJABLNIg2kvEOBi8PtsYikbN1OdMLHk8bTg,10
100
+ datachain-0.2.18.dist-info/RECORD,,