datachain 0.34.6__py3-none-any.whl → 0.34.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of datachain might be problematic. Click here for more details.
- datachain/asyn.py +11 -12
- datachain/cache.py +5 -5
- datachain/catalog/catalog.py +75 -83
- datachain/catalog/loader.py +3 -3
- datachain/checkpoint.py +1 -2
- datachain/cli/__init__.py +2 -4
- datachain/cli/commands/datasets.py +13 -13
- datachain/cli/commands/ls.py +4 -4
- datachain/cli/commands/query.py +3 -3
- datachain/cli/commands/show.py +2 -2
- datachain/cli/parser/job.py +1 -1
- datachain/cli/parser/utils.py +1 -2
- datachain/cli/utils.py +1 -2
- datachain/client/azure.py +2 -2
- datachain/client/fsspec.py +11 -21
- datachain/client/gcs.py +3 -3
- datachain/client/http.py +4 -4
- datachain/client/local.py +4 -4
- datachain/client/s3.py +3 -3
- datachain/config.py +4 -8
- datachain/data_storage/db_engine.py +5 -5
- datachain/data_storage/metastore.py +107 -107
- datachain/data_storage/schema.py +18 -24
- datachain/data_storage/sqlite.py +21 -28
- datachain/data_storage/warehouse.py +13 -13
- datachain/dataset.py +64 -70
- datachain/delta.py +21 -18
- datachain/diff/__init__.py +13 -13
- datachain/func/aggregate.py +9 -11
- datachain/func/array.py +12 -12
- datachain/func/base.py +7 -4
- datachain/func/conditional.py +9 -13
- datachain/func/func.py +45 -42
- datachain/func/numeric.py +5 -7
- datachain/func/string.py +2 -2
- datachain/hash_utils.py +54 -81
- datachain/job.py +8 -8
- datachain/lib/arrow.py +17 -14
- datachain/lib/audio.py +6 -6
- datachain/lib/clip.py +5 -4
- datachain/lib/convert/python_to_sql.py +4 -22
- datachain/lib/convert/values_to_tuples.py +4 -9
- datachain/lib/data_model.py +20 -19
- datachain/lib/dataset_info.py +6 -6
- datachain/lib/dc/csv.py +10 -10
- datachain/lib/dc/database.py +28 -29
- datachain/lib/dc/datachain.py +98 -97
- datachain/lib/dc/datasets.py +22 -22
- datachain/lib/dc/hf.py +4 -4
- datachain/lib/dc/json.py +9 -10
- datachain/lib/dc/listings.py +5 -8
- datachain/lib/dc/pandas.py +3 -6
- datachain/lib/dc/parquet.py +5 -5
- datachain/lib/dc/records.py +5 -5
- datachain/lib/dc/storage.py +12 -12
- datachain/lib/dc/storage_pattern.py +2 -2
- datachain/lib/dc/utils.py +11 -14
- datachain/lib/dc/values.py +3 -6
- datachain/lib/file.py +26 -26
- datachain/lib/hf.py +7 -5
- datachain/lib/image.py +13 -13
- datachain/lib/listing.py +5 -5
- datachain/lib/listing_info.py +1 -2
- datachain/lib/meta_formats.py +1 -2
- datachain/lib/model_store.py +3 -3
- datachain/lib/namespaces.py +4 -6
- datachain/lib/projects.py +5 -9
- datachain/lib/pytorch.py +10 -10
- datachain/lib/settings.py +23 -23
- datachain/lib/signal_schema.py +52 -44
- datachain/lib/text.py +8 -7
- datachain/lib/udf.py +25 -17
- datachain/lib/udf_signature.py +11 -11
- datachain/lib/video.py +3 -4
- datachain/lib/webdataset.py +30 -35
- datachain/lib/webdataset_laion.py +15 -16
- datachain/listing.py +4 -4
- datachain/model/bbox.py +3 -1
- datachain/namespace.py +4 -4
- datachain/node.py +6 -6
- datachain/nodes_thread_pool.py +0 -1
- datachain/plugins.py +1 -7
- datachain/project.py +4 -4
- datachain/query/batch.py +7 -8
- datachain/query/dataset.py +80 -87
- datachain/query/dispatch.py +7 -7
- datachain/query/metrics.py +3 -4
- datachain/query/params.py +2 -3
- datachain/query/schema.py +7 -6
- datachain/query/session.py +7 -7
- datachain/query/udf.py +8 -7
- datachain/query/utils.py +3 -5
- datachain/remote/studio.py +33 -39
- datachain/script_meta.py +12 -12
- datachain/sql/sqlite/base.py +6 -9
- datachain/studio.py +30 -30
- datachain/toolkit/split.py +1 -2
- datachain/utils.py +21 -21
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/METADATA +2 -3
- datachain-0.34.7.dist-info/RECORD +173 -0
- datachain-0.34.6.dist-info/RECORD +0 -173
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/WHEEL +0 -0
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/entry_points.txt +0 -0
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/licenses/LICENSE +0 -0
- {datachain-0.34.6.dist-info → datachain-0.34.7.dist-info}/top_level.txt +0 -0
datachain/diff/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
2
|
from enum import Enum
|
|
3
|
-
from typing import TYPE_CHECKING
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
4
|
|
|
5
5
|
from datachain.func import case, ifelse, isnone, or_
|
|
6
6
|
from datachain.lib.signal_schema import SignalSchema
|
|
@@ -27,22 +27,22 @@ class CompareStatus(str, Enum):
|
|
|
27
27
|
def _compare( # noqa: C901, PLR0912
|
|
28
28
|
left: "DataChain",
|
|
29
29
|
right: "DataChain",
|
|
30
|
-
on:
|
|
31
|
-
right_on:
|
|
32
|
-
compare:
|
|
33
|
-
right_compare:
|
|
30
|
+
on: str | Sequence[str],
|
|
31
|
+
right_on: str | Sequence[str] | None = None,
|
|
32
|
+
compare: str | Sequence[str] | None = None,
|
|
33
|
+
right_compare: str | Sequence[str] | None = None,
|
|
34
34
|
added: bool = True,
|
|
35
35
|
deleted: bool = True,
|
|
36
36
|
modified: bool = True,
|
|
37
37
|
same: bool = True,
|
|
38
|
-
status_col:
|
|
38
|
+
status_col: str | None = None,
|
|
39
39
|
) -> "DataChain":
|
|
40
40
|
"""Comparing two chains by identifying rows that are added, deleted, modified
|
|
41
41
|
or same"""
|
|
42
42
|
rname = "right_"
|
|
43
43
|
schema = left.signals_schema # final chain must have schema from left chain
|
|
44
44
|
|
|
45
|
-
def _to_list(obj:
|
|
45
|
+
def _to_list(obj: str | Sequence[str] | None) -> list[str] | None:
|
|
46
46
|
if obj is None:
|
|
47
47
|
return None
|
|
48
48
|
return [obj] if isinstance(obj, str) else list(obj)
|
|
@@ -109,7 +109,7 @@ def _compare( # noqa: C901, PLR0912
|
|
|
109
109
|
modified_cond = or_( # type: ignore[assignment]
|
|
110
110
|
*[
|
|
111
111
|
C(c) != (C(f"{rname}{rc}") if c == rc else C(rc))
|
|
112
|
-
for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
|
|
112
|
+
for c, rc in zip(compare, right_compare, strict=False) # type: ignore[arg-type]
|
|
113
113
|
]
|
|
114
114
|
)
|
|
115
115
|
|
|
@@ -133,7 +133,7 @@ def _compare( # noqa: C901, PLR0912
|
|
|
133
133
|
C(f"{rname + l_on if on == right_on else r_on}"),
|
|
134
134
|
C(l_on),
|
|
135
135
|
)
|
|
136
|
-
for l_on, r_on in zip(on, right_on) # type: ignore[arg-type]
|
|
136
|
+
for l_on, r_on in zip(on, right_on, strict=False) # type: ignore[arg-type]
|
|
137
137
|
}
|
|
138
138
|
)
|
|
139
139
|
.select_except(ldiff_col, rdiff_col)
|
|
@@ -168,10 +168,10 @@ def _compare( # noqa: C901, PLR0912
|
|
|
168
168
|
def compare_and_split(
|
|
169
169
|
left: "DataChain",
|
|
170
170
|
right: "DataChain",
|
|
171
|
-
on:
|
|
172
|
-
right_on:
|
|
173
|
-
compare:
|
|
174
|
-
right_compare:
|
|
171
|
+
on: str | Sequence[str],
|
|
172
|
+
right_on: str | Sequence[str] | None = None,
|
|
173
|
+
compare: str | Sequence[str] | None = None,
|
|
174
|
+
right_compare: str | Sequence[str] | None = None,
|
|
175
175
|
added: bool = True,
|
|
176
176
|
deleted: bool = True,
|
|
177
177
|
modified: bool = True,
|
datachain/func/aggregate.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional, Union
|
|
2
|
-
|
|
3
1
|
from sqlalchemy import func as sa_func
|
|
4
2
|
|
|
5
3
|
from datachain.query.schema import Column
|
|
@@ -8,7 +6,7 @@ from datachain.sql.functions import aggregate
|
|
|
8
6
|
from .func import Func
|
|
9
7
|
|
|
10
8
|
|
|
11
|
-
def count(col: Optional[Union[str, Column]] = None) -> Func:
|
|
9
|
+
def count(col: str | Column | None = None) -> Func:
|
|
12
10
|
"""
|
|
13
11
|
Returns a COUNT aggregate SQL function for the specified column.
|
|
14
12
|
|
|
@@ -44,7 +42,7 @@ def count(col: Optional[Union[str, Column]] = None) -> Func:
|
|
|
44
42
|
)
|
|
45
43
|
|
|
46
44
|
|
|
47
|
-
def sum(col: Union[str, Column]) -> Func:
|
|
45
|
+
def sum(col: str | Column) -> Func:
|
|
48
46
|
"""
|
|
49
47
|
Returns the SUM aggregate SQL function for the specified column.
|
|
50
48
|
|
|
@@ -74,7 +72,7 @@ def sum(col: Union[str, Column]) -> Func:
|
|
|
74
72
|
return Func("sum", inner=sa_func.sum, cols=[col])
|
|
75
73
|
|
|
76
74
|
|
|
77
|
-
def avg(col: Union[str, Column]) -> Func:
|
|
75
|
+
def avg(col: str | Column) -> Func:
|
|
78
76
|
"""
|
|
79
77
|
Returns the AVG aggregate SQL function for the specified column.
|
|
80
78
|
|
|
@@ -104,7 +102,7 @@ def avg(col: Union[str, Column]) -> Func:
|
|
|
104
102
|
return Func("avg", inner=aggregate.avg, cols=[col], result_type=float)
|
|
105
103
|
|
|
106
104
|
|
|
107
|
-
def min(col: Union[str, Column]) -> Func:
|
|
105
|
+
def min(col: str | Column) -> Func:
|
|
108
106
|
"""
|
|
109
107
|
Returns the MIN aggregate SQL function for the specified column.
|
|
110
108
|
|
|
@@ -134,7 +132,7 @@ def min(col: Union[str, Column]) -> Func:
|
|
|
134
132
|
return Func("min", inner=sa_func.min, cols=[col])
|
|
135
133
|
|
|
136
134
|
|
|
137
|
-
def max(col: Union[str, Column]) -> Func:
|
|
135
|
+
def max(col: str | Column) -> Func:
|
|
138
136
|
"""
|
|
139
137
|
Returns the MAX aggregate SQL function for the given column name.
|
|
140
138
|
|
|
@@ -164,7 +162,7 @@ def max(col: Union[str, Column]) -> Func:
|
|
|
164
162
|
return Func("max", inner=sa_func.max, cols=[col])
|
|
165
163
|
|
|
166
164
|
|
|
167
|
-
def any_value(col: Union[str, Column]) -> Func:
|
|
165
|
+
def any_value(col: str | Column) -> Func:
|
|
168
166
|
"""
|
|
169
167
|
Returns the ANY_VALUE aggregate SQL function for the given column name.
|
|
170
168
|
|
|
@@ -198,7 +196,7 @@ def any_value(col: Union[str, Column]) -> Func:
|
|
|
198
196
|
return Func("any_value", inner=aggregate.any_value, cols=[col])
|
|
199
197
|
|
|
200
198
|
|
|
201
|
-
def collect(col: Union[str, Column]) -> Func:
|
|
199
|
+
def collect(col: str | Column) -> Func:
|
|
202
200
|
"""
|
|
203
201
|
Returns the COLLECT aggregate SQL function for the given column name.
|
|
204
202
|
|
|
@@ -229,7 +227,7 @@ def collect(col: Union[str, Column]) -> Func:
|
|
|
229
227
|
return Func("collect", inner=aggregate.collect, cols=[col], is_array=True)
|
|
230
228
|
|
|
231
229
|
|
|
232
|
-
def concat(col:
|
|
230
|
+
def concat(col: str | Column, separator="") -> Func:
|
|
233
231
|
"""
|
|
234
232
|
Returns the CONCAT aggregate SQL function for the given column name.
|
|
235
233
|
|
|
@@ -348,7 +346,7 @@ def dense_rank() -> Func:
|
|
|
348
346
|
return Func("dense_rank", inner=sa_func.dense_rank, result_type=int, is_window=True)
|
|
349
347
|
|
|
350
348
|
|
|
351
|
-
def first(col:
|
|
349
|
+
def first(col: str | Column) -> Func:
|
|
352
350
|
"""
|
|
353
351
|
Returns the FIRST_VALUE window function for SQL queries.
|
|
354
352
|
|
datachain/func/array.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from collections.abc import Sequence
|
|
2
|
-
from typing import Any
|
|
2
|
+
from typing import Any
|
|
3
3
|
|
|
4
4
|
from datachain.query.schema import Column
|
|
5
5
|
from datachain.sql.functions import array
|
|
@@ -7,7 +7,7 @@ from datachain.sql.functions import array
|
|
|
7
7
|
from .func import Func
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
def cosine_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
|
|
10
|
+
def cosine_distance(*args: str | Column | Func | Sequence) -> Func:
|
|
11
11
|
"""
|
|
12
12
|
Returns the cosine distance between two vectors.
|
|
13
13
|
|
|
@@ -62,7 +62,7 @@ def cosine_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
|
|
|
62
62
|
)
|
|
63
63
|
|
|
64
64
|
|
|
65
|
-
def euclidean_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
|
|
65
|
+
def euclidean_distance(*args: str | Column | Func | Sequence) -> Func:
|
|
66
66
|
"""
|
|
67
67
|
Returns the Euclidean distance between two vectors.
|
|
68
68
|
|
|
@@ -115,7 +115,7 @@ def euclidean_distance(*args: Union[str, Column, Func, Sequence]) -> Func:
|
|
|
115
115
|
)
|
|
116
116
|
|
|
117
117
|
|
|
118
|
-
def length(arg: Union[str, Column, Func, Sequence]) -> Func:
|
|
118
|
+
def length(arg: str | Column | Func | Sequence) -> Func:
|
|
119
119
|
"""
|
|
120
120
|
Returns the length of the array.
|
|
121
121
|
|
|
@@ -151,7 +151,7 @@ def length(arg: Union[str, Column, Func, Sequence]) -> Func:
|
|
|
151
151
|
return Func("length", inner=array.length, cols=cols, args=args, result_type=int)
|
|
152
152
|
|
|
153
153
|
|
|
154
|
-
def contains(arr: Union[str, Column, Func, Sequence], elem: Any) -> Func:
|
|
154
|
+
def contains(arr: str | Column | Func | Sequence, elem: Any) -> Func:
|
|
155
155
|
"""
|
|
156
156
|
Checks whether the array contains the specified element.
|
|
157
157
|
|
|
@@ -196,9 +196,9 @@ def contains(arr: Union[str, Column, Func, Sequence], elem: Any) -> Func:
|
|
|
196
196
|
|
|
197
197
|
|
|
198
198
|
def slice(
|
|
199
|
-
arr:
|
|
199
|
+
arr: str | Column | Func | Sequence,
|
|
200
200
|
offset: int,
|
|
201
|
-
length:
|
|
201
|
+
length: int | None = None,
|
|
202
202
|
) -> Func:
|
|
203
203
|
"""
|
|
204
204
|
Returns a slice of the array starting from the specified offset.
|
|
@@ -272,7 +272,7 @@ def slice(
|
|
|
272
272
|
|
|
273
273
|
|
|
274
274
|
def join(
|
|
275
|
-
arr:
|
|
275
|
+
arr: str | Column | Func | Sequence,
|
|
276
276
|
sep: str = "",
|
|
277
277
|
) -> Func:
|
|
278
278
|
"""
|
|
@@ -322,7 +322,7 @@ def join(
|
|
|
322
322
|
)
|
|
323
323
|
|
|
324
324
|
|
|
325
|
-
def get_element(arg: Union[str, Column, Func, Sequence], index: int) -> Func:
|
|
325
|
+
def get_element(arg: str | Column | Func | Sequence, index: int) -> Func:
|
|
326
326
|
"""
|
|
327
327
|
Returns the element at the given index from the array.
|
|
328
328
|
If the index is out of bounds, it returns None or columns default value.
|
|
@@ -359,8 +359,8 @@ def get_element(arg: Union[str, Column, Func, Sequence], index: int) -> Func:
|
|
|
359
359
|
return str # if the array is empty, return str as default type
|
|
360
360
|
return None
|
|
361
361
|
|
|
362
|
-
cols:
|
|
363
|
-
args:
|
|
362
|
+
cols: str | Column | Func | Sequence | None
|
|
363
|
+
args: str | Column | Func | Sequence | int
|
|
364
364
|
|
|
365
365
|
if isinstance(arg, (str, Column, Func)):
|
|
366
366
|
cols = [arg]
|
|
@@ -379,7 +379,7 @@ def get_element(arg: Union[str, Column, Func, Sequence], index: int) -> Func:
|
|
|
379
379
|
)
|
|
380
380
|
|
|
381
381
|
|
|
382
|
-
def sip_hash_64(arg:
|
|
382
|
+
def sip_hash_64(arg: str | Column | Func | Sequence) -> Func:
|
|
383
383
|
"""
|
|
384
384
|
Returns the SipHash-64 hash of the array.
|
|
385
385
|
|
datachain/func/base.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from abc import ABCMeta, abstractmethod
|
|
2
|
-
from
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
3
4
|
|
|
4
5
|
if TYPE_CHECKING:
|
|
5
6
|
from sqlalchemy import TableClause
|
|
@@ -12,12 +13,14 @@ class Function:
|
|
|
12
13
|
__metaclass__ = ABCMeta
|
|
13
14
|
|
|
14
15
|
name: str
|
|
16
|
+
cols: Sequence
|
|
17
|
+
args: Sequence
|
|
15
18
|
|
|
16
19
|
@abstractmethod
|
|
17
20
|
def get_column(
|
|
18
21
|
self,
|
|
19
|
-
signals_schema:
|
|
20
|
-
label:
|
|
21
|
-
table:
|
|
22
|
+
signals_schema: "SignalSchema | None" = None,
|
|
23
|
+
label: str | None = None,
|
|
24
|
+
table: "TableClause | None" = None,
|
|
22
25
|
) -> "Column":
|
|
23
26
|
pass
|
datachain/func/conditional.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional, Union
|
|
2
|
-
|
|
3
1
|
from sqlalchemy import ColumnElement
|
|
4
2
|
from sqlalchemy import and_ as sql_and
|
|
5
3
|
from sqlalchemy import case as sql_case
|
|
@@ -12,10 +10,10 @@ from datachain.sql.functions import conditional
|
|
|
12
10
|
|
|
13
11
|
from .func import Func
|
|
14
12
|
|
|
15
|
-
CaseT =
|
|
13
|
+
CaseT = int | float | complex | bool | str | Func | ColumnElement
|
|
16
14
|
|
|
17
15
|
|
|
18
|
-
def greatest(*args: Union[str, Column, Func, float]) -> Func:
|
|
16
|
+
def greatest(*args: str | Column | Func | float) -> Func:
|
|
19
17
|
"""
|
|
20
18
|
Returns the greatest (largest) value from the given input values.
|
|
21
19
|
|
|
@@ -56,7 +54,7 @@ def greatest(*args: Union[str, Column, Func, float]) -> Func:
|
|
|
56
54
|
)
|
|
57
55
|
|
|
58
56
|
|
|
59
|
-
def least(*args: Union[str, Column, Func, float]) -> Func:
|
|
57
|
+
def least(*args: str | Column | Func | float) -> Func:
|
|
60
58
|
"""
|
|
61
59
|
Returns the least (smallest) value from the given input values.
|
|
62
60
|
|
|
@@ -94,7 +92,7 @@ def least(*args: Union[str, Column, Func, float]) -> Func:
|
|
|
94
92
|
|
|
95
93
|
|
|
96
94
|
def case(
|
|
97
|
-
*args: tuple[
|
|
95
|
+
*args: tuple[ColumnElement | Func | bool, CaseT], else_: CaseT | None = None
|
|
98
96
|
) -> Func:
|
|
99
97
|
"""
|
|
100
98
|
Returns a case expression that evaluates a list of conditions and returns
|
|
@@ -163,9 +161,7 @@ def case(
|
|
|
163
161
|
return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_)
|
|
164
162
|
|
|
165
163
|
|
|
166
|
-
def ifelse(
|
|
167
|
-
condition: Union[ColumnElement, Func], if_val: CaseT, else_val: CaseT
|
|
168
|
-
) -> Func:
|
|
164
|
+
def ifelse(condition: ColumnElement | Func, if_val: CaseT, else_val: CaseT) -> Func:
|
|
169
165
|
"""
|
|
170
166
|
Returns an if-else expression that evaluates a condition and returns one
|
|
171
167
|
of two values based on the result. Values can be Python primitives
|
|
@@ -193,7 +189,7 @@ def ifelse(
|
|
|
193
189
|
return case((condition, if_val), else_=else_val)
|
|
194
190
|
|
|
195
191
|
|
|
196
|
-
def isnone(col: Union[str, ColumnElement]) -> Func:
|
|
192
|
+
def isnone(col: str | ColumnElement) -> Func:
|
|
197
193
|
"""
|
|
198
194
|
Returns a function that checks if the column value is `None` (NULL in DB).
|
|
199
195
|
|
|
@@ -221,7 +217,7 @@ def isnone(col: Union[str, ColumnElement]) -> Func:
|
|
|
221
217
|
return case((col.is_(None) if col is not None else True, True), else_=False)
|
|
222
218
|
|
|
223
219
|
|
|
224
|
-
def or_(*args: Union[ColumnElement, Func]) -> Func:
|
|
220
|
+
def or_(*args: ColumnElement | Func) -> Func:
|
|
225
221
|
"""
|
|
226
222
|
Returns the function that produces conjunction of expressions joined by OR
|
|
227
223
|
logical operator.
|
|
@@ -256,7 +252,7 @@ def or_(*args: Union[ColumnElement, Func]) -> Func:
|
|
|
256
252
|
return Func("or", inner=sql_or, cols=cols, args=func_args, result_type=bool)
|
|
257
253
|
|
|
258
254
|
|
|
259
|
-
def and_(*args: Union[ColumnElement, Func]) -> Func:
|
|
255
|
+
def and_(*args: ColumnElement | Func) -> Func:
|
|
260
256
|
"""
|
|
261
257
|
Returns the function that produces conjunction of expressions joined by AND
|
|
262
258
|
logical operator.
|
|
@@ -291,7 +287,7 @@ def and_(*args: Union[ColumnElement, Func]) -> Func:
|
|
|
291
287
|
return Func("and", inner=sql_and, cols=cols, args=func_args, result_type=bool)
|
|
292
288
|
|
|
293
289
|
|
|
294
|
-
def not_(arg:
|
|
290
|
+
def not_(arg: ColumnElement | Func) -> Func:
|
|
295
291
|
"""
|
|
296
292
|
Returns the function that produces NOT of the given expressions.
|
|
297
293
|
|
datachain/func/func.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import inspect
|
|
2
|
-
from collections.abc import Sequence
|
|
3
|
-
from typing import TYPE_CHECKING, Any,
|
|
2
|
+
from collections.abc import Callable, Sequence
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Union, get_args, get_origin
|
|
4
4
|
|
|
5
5
|
from sqlalchemy import BindParameter, Case, ColumnElement, Integer, cast, desc
|
|
6
6
|
from sqlalchemy.sql import func as sa_func
|
|
@@ -22,26 +22,29 @@ if TYPE_CHECKING:
|
|
|
22
22
|
from .window import Window
|
|
23
23
|
|
|
24
24
|
|
|
25
|
-
ColT = Union[str, Column, ColumnElement, "Func"
|
|
25
|
+
ColT = Union[str, tuple, Column, ColumnElement, "Func"]
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class Func(Function): # noqa: PLW1641
|
|
29
29
|
"""Represents a function to be applied to a column in a SQL query."""
|
|
30
30
|
|
|
31
|
+
cols: Sequence[ColT]
|
|
32
|
+
args: Sequence[Any]
|
|
33
|
+
|
|
31
34
|
def __init__(
|
|
32
35
|
self,
|
|
33
36
|
name: str,
|
|
34
37
|
inner: Callable,
|
|
35
|
-
cols:
|
|
36
|
-
args:
|
|
37
|
-
kwargs:
|
|
38
|
-
result_type:
|
|
39
|
-
type_from_args:
|
|
38
|
+
cols: Sequence[ColT] | None = None,
|
|
39
|
+
args: Sequence[Any] | None = None,
|
|
40
|
+
kwargs: dict[str, Any] | None = None,
|
|
41
|
+
result_type: "DataType | None" = None,
|
|
42
|
+
type_from_args: Callable[..., "DataType"] | None = None,
|
|
40
43
|
is_array: bool = False,
|
|
41
44
|
from_array: bool = False,
|
|
42
45
|
is_window: bool = False,
|
|
43
|
-
window:
|
|
44
|
-
label:
|
|
46
|
+
window: "Window | None" = None,
|
|
47
|
+
label: str | None = None,
|
|
45
48
|
) -> None:
|
|
46
49
|
self.name = name
|
|
47
50
|
self.inner = inner
|
|
@@ -95,7 +98,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
95
98
|
else []
|
|
96
99
|
)
|
|
97
100
|
|
|
98
|
-
def _db_col_type(self, signals_schema: "SignalSchema") ->
|
|
101
|
+
def _db_col_type(self, signals_schema: "SignalSchema") -> "DataType | None":
|
|
99
102
|
if not self._db_cols:
|
|
100
103
|
return None
|
|
101
104
|
|
|
@@ -125,51 +128,51 @@ class Func(Function): # noqa: PLW1641
|
|
|
125
128
|
|
|
126
129
|
return list[col_type] if self.is_array else col_type # type: ignore[valid-type]
|
|
127
130
|
|
|
128
|
-
def __add__(self, other:
|
|
131
|
+
def __add__(self, other: ColT | float) -> "Func":
|
|
129
132
|
if isinstance(other, (int, float)):
|
|
130
133
|
return Func("add", lambda a: a + other, [self])
|
|
131
134
|
return Func("add", lambda a1, a2: a1 + a2, [self, other])
|
|
132
135
|
|
|
133
|
-
def __radd__(self, other:
|
|
136
|
+
def __radd__(self, other: ColT | float) -> "Func":
|
|
134
137
|
if isinstance(other, (int, float)):
|
|
135
138
|
return Func("add", lambda a: other + a, [self])
|
|
136
139
|
return Func("add", lambda a1, a2: a1 + a2, [other, self])
|
|
137
140
|
|
|
138
|
-
def __sub__(self, other:
|
|
141
|
+
def __sub__(self, other: ColT | float) -> "Func":
|
|
139
142
|
if isinstance(other, (int, float)):
|
|
140
143
|
return Func("sub", lambda a: a - other, [self])
|
|
141
144
|
return Func("sub", lambda a1, a2: a1 - a2, [self, other])
|
|
142
145
|
|
|
143
|
-
def __rsub__(self, other:
|
|
146
|
+
def __rsub__(self, other: ColT | float) -> "Func":
|
|
144
147
|
if isinstance(other, (int, float)):
|
|
145
148
|
return Func("sub", lambda a: other - a, [self])
|
|
146
149
|
return Func("sub", lambda a1, a2: a1 - a2, [other, self])
|
|
147
150
|
|
|
148
|
-
def __mul__(self, other:
|
|
151
|
+
def __mul__(self, other: ColT | float) -> "Func":
|
|
149
152
|
if isinstance(other, (int, float)):
|
|
150
153
|
return Func("mul", lambda a: a * other, [self])
|
|
151
154
|
return Func("mul", lambda a1, a2: a1 * a2, [self, other])
|
|
152
155
|
|
|
153
|
-
def __rmul__(self, other:
|
|
156
|
+
def __rmul__(self, other: ColT | float) -> "Func":
|
|
154
157
|
if isinstance(other, (int, float)):
|
|
155
158
|
return Func("mul", lambda a: other * a, [self])
|
|
156
159
|
return Func("mul", lambda a1, a2: a1 * a2, [other, self])
|
|
157
160
|
|
|
158
|
-
def __truediv__(self, other:
|
|
161
|
+
def __truediv__(self, other: ColT | float) -> "Func":
|
|
159
162
|
if isinstance(other, (int, float)):
|
|
160
163
|
return Func("div", lambda a: _truediv(a, other), [self], result_type=float)
|
|
161
164
|
return Func(
|
|
162
165
|
"div", lambda a1, a2: _truediv(a1, a2), [self, other], result_type=float
|
|
163
166
|
)
|
|
164
167
|
|
|
165
|
-
def __rtruediv__(self, other:
|
|
168
|
+
def __rtruediv__(self, other: ColT | float) -> "Func":
|
|
166
169
|
if isinstance(other, (int, float)):
|
|
167
170
|
return Func("div", lambda a: _truediv(other, a), [self], result_type=float)
|
|
168
171
|
return Func(
|
|
169
172
|
"div", lambda a1, a2: _truediv(a1, a2), [other, self], result_type=float
|
|
170
173
|
)
|
|
171
174
|
|
|
172
|
-
def __floordiv__(self, other:
|
|
175
|
+
def __floordiv__(self, other: ColT | float) -> "Func":
|
|
173
176
|
if isinstance(other, (int, float)):
|
|
174
177
|
return Func(
|
|
175
178
|
"floordiv", lambda a: _floordiv(a, other), [self], result_type=int
|
|
@@ -178,7 +181,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
178
181
|
"floordiv", lambda a1, a2: _floordiv(a1, a2), [self, other], result_type=int
|
|
179
182
|
)
|
|
180
183
|
|
|
181
|
-
def __rfloordiv__(self, other:
|
|
184
|
+
def __rfloordiv__(self, other: ColT | float) -> "Func":
|
|
182
185
|
if isinstance(other, (int, float)):
|
|
183
186
|
return Func(
|
|
184
187
|
"floordiv", lambda a: _floordiv(other, a), [self], result_type=int
|
|
@@ -187,17 +190,17 @@ class Func(Function): # noqa: PLW1641
|
|
|
187
190
|
"floordiv", lambda a1, a2: _floordiv(a1, a2), [other, self], result_type=int
|
|
188
191
|
)
|
|
189
192
|
|
|
190
|
-
def __mod__(self, other:
|
|
193
|
+
def __mod__(self, other: ColT | float) -> "Func":
|
|
191
194
|
if isinstance(other, (int, float)):
|
|
192
195
|
return Func("mod", lambda a: a % other, [self], result_type=int)
|
|
193
196
|
return Func("mod", lambda a1, a2: a1 % a2, [self, other], result_type=int)
|
|
194
197
|
|
|
195
|
-
def __rmod__(self, other:
|
|
198
|
+
def __rmod__(self, other: ColT | float) -> "Func":
|
|
196
199
|
if isinstance(other, (int, float)):
|
|
197
200
|
return Func("mod", lambda a: other % a, [self], result_type=int)
|
|
198
201
|
return Func("mod", lambda a1, a2: a1 % a2, [other, self], result_type=int)
|
|
199
202
|
|
|
200
|
-
def __and__(self, other:
|
|
203
|
+
def __and__(self, other: ColT | float) -> "Func":
|
|
201
204
|
if isinstance(other, (int, float)):
|
|
202
205
|
return Func(
|
|
203
206
|
"and", lambda a: numeric.bit_and(a, other), [self], result_type=int
|
|
@@ -209,7 +212,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
209
212
|
result_type=int,
|
|
210
213
|
)
|
|
211
214
|
|
|
212
|
-
def __rand__(self, other:
|
|
215
|
+
def __rand__(self, other: ColT | float) -> "Func":
|
|
213
216
|
if isinstance(other, (int, float)):
|
|
214
217
|
return Func(
|
|
215
218
|
"and", lambda a: numeric.bit_and(other, a), [self], result_type=int
|
|
@@ -221,7 +224,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
221
224
|
result_type=int,
|
|
222
225
|
)
|
|
223
226
|
|
|
224
|
-
def __or__(self, other:
|
|
227
|
+
def __or__(self, other: ColT | float) -> "Func":
|
|
225
228
|
if isinstance(other, (int, float)):
|
|
226
229
|
return Func(
|
|
227
230
|
"or", lambda a: numeric.bit_or(a, other), [self], result_type=int
|
|
@@ -230,7 +233,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
230
233
|
"or", lambda a1, a2: numeric.bit_or(a1, a2), [self, other], result_type=int
|
|
231
234
|
)
|
|
232
235
|
|
|
233
|
-
def __ror__(self, other:
|
|
236
|
+
def __ror__(self, other: ColT | float) -> "Func":
|
|
234
237
|
if isinstance(other, (int, float)):
|
|
235
238
|
return Func(
|
|
236
239
|
"or", lambda a: numeric.bit_or(other, a), [self], result_type=int
|
|
@@ -239,7 +242,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
239
242
|
"or", lambda a1, a2: numeric.bit_or(a1, a2), [other, self], result_type=int
|
|
240
243
|
)
|
|
241
244
|
|
|
242
|
-
def __xor__(self, other:
|
|
245
|
+
def __xor__(self, other: ColT | float) -> "Func":
|
|
243
246
|
if isinstance(other, (int, float)):
|
|
244
247
|
return Func(
|
|
245
248
|
"xor", lambda a: numeric.bit_xor(a, other), [self], result_type=int
|
|
@@ -251,7 +254,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
251
254
|
result_type=int,
|
|
252
255
|
)
|
|
253
256
|
|
|
254
|
-
def __rxor__(self, other:
|
|
257
|
+
def __rxor__(self, other: ColT | float) -> "Func":
|
|
255
258
|
if isinstance(other, (int, float)):
|
|
256
259
|
return Func(
|
|
257
260
|
"xor", lambda a: numeric.bit_xor(other, a), [self], result_type=int
|
|
@@ -263,7 +266,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
263
266
|
result_type=int,
|
|
264
267
|
)
|
|
265
268
|
|
|
266
|
-
def __rshift__(self, other:
|
|
269
|
+
def __rshift__(self, other: ColT | float) -> "Func":
|
|
267
270
|
if isinstance(other, (int, float)):
|
|
268
271
|
return Func(
|
|
269
272
|
"rshift",
|
|
@@ -278,7 +281,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
278
281
|
result_type=int,
|
|
279
282
|
)
|
|
280
283
|
|
|
281
|
-
def __rrshift__(self, other:
|
|
284
|
+
def __rrshift__(self, other: ColT | float) -> "Func":
|
|
282
285
|
if isinstance(other, (int, float)):
|
|
283
286
|
return Func(
|
|
284
287
|
"rshift",
|
|
@@ -293,7 +296,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
293
296
|
result_type=int,
|
|
294
297
|
)
|
|
295
298
|
|
|
296
|
-
def __lshift__(self, other:
|
|
299
|
+
def __lshift__(self, other: ColT | float) -> "Func":
|
|
297
300
|
if isinstance(other, (int, float)):
|
|
298
301
|
return Func(
|
|
299
302
|
"lshift",
|
|
@@ -308,7 +311,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
308
311
|
result_type=int,
|
|
309
312
|
)
|
|
310
313
|
|
|
311
|
-
def __rlshift__(self, other:
|
|
314
|
+
def __rlshift__(self, other: ColT | float) -> "Func":
|
|
312
315
|
if isinstance(other, (int, float)):
|
|
313
316
|
return Func(
|
|
314
317
|
"lshift",
|
|
@@ -323,12 +326,12 @@ class Func(Function): # noqa: PLW1641
|
|
|
323
326
|
result_type=int,
|
|
324
327
|
)
|
|
325
328
|
|
|
326
|
-
def __lt__(self, other:
|
|
329
|
+
def __lt__(self, other: ColT | float) -> "Func":
|
|
327
330
|
if isinstance(other, (int, float)):
|
|
328
331
|
return Func("lt", lambda a: a < other, [self], result_type=bool)
|
|
329
332
|
return Func("lt", lambda a1, a2: a1 < a2, [self, other], result_type=bool)
|
|
330
333
|
|
|
331
|
-
def __le__(self, other:
|
|
334
|
+
def __le__(self, other: ColT | float) -> "Func":
|
|
332
335
|
if isinstance(other, (int, float)):
|
|
333
336
|
return Func("le", lambda a: a <= other, [self], result_type=bool)
|
|
334
337
|
return Func("le", lambda a1, a2: a1 <= a2, [self, other], result_type=bool)
|
|
@@ -343,12 +346,12 @@ class Func(Function): # noqa: PLW1641
|
|
|
343
346
|
return Func("ne", lambda a: a != other, [self], result_type=bool)
|
|
344
347
|
return Func("ne", lambda a1, a2: a1 != a2, [self, other], result_type=bool)
|
|
345
348
|
|
|
346
|
-
def __gt__(self, other:
|
|
349
|
+
def __gt__(self, other: ColT | float) -> "Func":
|
|
347
350
|
if isinstance(other, (int, float)):
|
|
348
351
|
return Func("gt", lambda a: a > other, [self], result_type=bool)
|
|
349
352
|
return Func("gt", lambda a1, a2: a1 > a2, [self, other], result_type=bool)
|
|
350
353
|
|
|
351
|
-
def __ge__(self, other:
|
|
354
|
+
def __ge__(self, other: ColT | float) -> "Func":
|
|
352
355
|
if isinstance(other, (int, float)):
|
|
353
356
|
return Func("ge", lambda a: a >= other, [self], result_type=bool)
|
|
354
357
|
return Func("ge", lambda a1, a2: a1 >= a2, [self, other], result_type=bool)
|
|
@@ -369,7 +372,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
369
372
|
label,
|
|
370
373
|
)
|
|
371
374
|
|
|
372
|
-
def get_col_name(self, label:
|
|
375
|
+
def get_col_name(self, label: str | None = None) -> str:
|
|
373
376
|
if label:
|
|
374
377
|
return label
|
|
375
378
|
if self.col_label:
|
|
@@ -384,7 +387,7 @@ class Func(Function): # noqa: PLW1641
|
|
|
384
387
|
return self.name
|
|
385
388
|
|
|
386
389
|
def get_result_type(
|
|
387
|
-
self, signals_schema:
|
|
390
|
+
self, signals_schema: "SignalSchema | None" = None
|
|
388
391
|
) -> "DataType":
|
|
389
392
|
if self.result_type:
|
|
390
393
|
return self.result_type
|
|
@@ -408,9 +411,9 @@ class Func(Function): # noqa: PLW1641
|
|
|
408
411
|
|
|
409
412
|
def get_column(
|
|
410
413
|
self,
|
|
411
|
-
signals_schema:
|
|
412
|
-
label:
|
|
413
|
-
table:
|
|
414
|
+
signals_schema: "SignalSchema | None" = None,
|
|
415
|
+
label: str | None = None,
|
|
416
|
+
table: "TableClause | None" = None,
|
|
414
417
|
) -> Column:
|
|
415
418
|
col_type = self.get_result_type(signals_schema)
|
|
416
419
|
sql_type = python_to_sql(col_type)
|